#!/usr/bin/env python3
"""
Phase 8: Diagnostic Probes
A. Mongolia robustness — re-run key results excluding MNG
B. Income balance fragility — test CCA tipping on income_balance_gdp & trade_balance_gdp
C. CCA entry timing — when do CCA countries enter the panel?
D. Country-level influence — leave-one-country-out (DFBETAS-style) sensitivity of the Z₁ coefficient
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats as sp_stats

PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Phase-specific directories. The table directory is created up front so the
# CSV writes later in the script never hit a missing folder.
CCA_DIR = PROJECT_DIR.joinpath("cca_tipping")
PROCESSED_DIR = CCA_DIR.joinpath("data", "processed")
TABLE_DIR = CCA_DIR.joinpath("output", "tables")
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# ── Country Sets ───────────────────────────────────────────────────────────
# All 13 CCA economies handled by the probes.
CCA_ALL = {
    'ARM', 'AZE', 'BLR', 'GEO', 'KAZ', 'KGZ', 'MDA',
    'MNG', 'RUS', 'TJK', 'TKM', 'UKR', 'UZB',
}
# The 8 non-commodity CCA economies; the remaining 5 (AZE, KAZ, RUS, TKM, UZB)
# form the commodity group.
CCA_NON_COMMODITY = {'ARM', 'BLR', 'GEO', 'KGZ', 'MDA', 'MNG', 'TJK', 'UKR'}
# Same group with Mongolia removed (7 countries), used by the MNG-robustness probes.
CCA_NC_NO_MNG = CCA_NON_COMMODITY.difference({'MNG'})

def star(p):
    """Return conventional significance stars for p-value ``p``.

    ``"***"`` if p < 0.01, ``"**"`` if p < 0.05, ``"*"`` if p < 0.10,
    otherwise the empty string (this also covers NaN, since every
    comparison with NaN is False).
    """
    for cutoff, marker in ((0.01, "***"), (0.05, "**"), (0.10, "*")):
        if p < cutoff:
            return marker
    return ""

def run_gls(df, dep_var, indep_vars, min_obs=50):
    """Fit ``PanelGLS`` of ``dep_var`` on ``indep_vars`` over complete cases.

    Rows with a missing value in ``dep_var`` or any of ``indep_vars`` are
    dropped first. ``df`` must also contain ``iso3`` and ``year`` columns,
    which are passed to ``PanelGLS.fit`` as the panel identifiers.

    Args:
        df: Panel DataFrame.
        dep_var: Name of the dependent-variable column.
        indep_vars: Sequence (list or tuple) of regressor column names.
        min_obs: Minimum number of complete rows required to attempt a fit.
            The default of 50 is the threshold this script has always used.

    Returns:
        A dict with keys ``r_squared``, ``n_obs``, ``n_countries``, ``ssr``,
        and ``coefficients`` / ``std_errors`` / ``p_values``, each a dict
        keyed by regressor name. Returns ``None`` when fewer than ``min_obs``
        complete rows remain, or when the fit raises; in that case the error
        is printed.
    """
    indep_vars = list(indep_vars)  # accept tuples too; `[dep_var] + tuple` would raise
    comp = df.dropna(subset=[dep_var] + indep_vars).copy()
    if len(comp) < min_obs:
        return None
    y = comp[dep_var].values
    X = comp[indep_vars].values
    try:
        gls = PanelGLS()
        gls.fit(y, X, comp['iso3'].values, comp['year'].values)
        # NOTE(review): residuals are formed from the raw y and X. If PanelGLS
        # transforms the data internally (GLS weighting, absorbed fixed
        # effects), this is not the model's own weighted/within SSR. The A3
        # threshold search compares these SSRs, so TODO confirm in model.py.
        resid = y - X @ gls.beta
        return {
            'r_squared': gls.r_squared,
            'n_obs': gls.n_obs,
            'n_countries': gls.n_countries,
            'coefficients': dict(zip(indep_vars, gls.beta)),
            'std_errors': dict(zip(indep_vars, gls.se)),
            'p_values': dict(zip(indep_vars, gls.pvalues)),
            'ssr': float(np.sum(resid ** 2)),
        }
    except Exception as e:
        # Best effort by design: one failed sub-sample fit must not abort the
        # whole diagnostic run. Callers treat None as "no estimate".
        print(f"    GLS error: {e}")
        return None


# ══════════════════════════════════════════════════════════════════════════
# LOAD DATA
# ══════════════════════════════════════════════════════════════════════════
# Estimation window: country-years with non-missing CA/GDP, 1986-2024 inclusive.
panel = pd.read_csv(PROCESSED_DIR / "cca_panel.csv")
keep = panel['ca_gdp'].notna() & panel['year'].between(1986, 2024)
est = panel.loc[keep].copy()

demo_vars = ['Z_1', 'Z_2', 'Z_3']
baseline_controls = [
    'fiscal_bal_gdp', 'kaopen', 'expected_growth',
    'nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp',
]
base_vars = [*demo_vars, *baseline_controls]
# Baseline estimation sample: complete cases on CA/GDP and every regressor.
base_sample = est.dropna(subset=[*base_vars, 'ca_gdp']).copy()

rule = "=" * 80
print(rule)
print("PHASE 8: DIAGNOSTIC PROBES")
print(rule)

# ══════════════════════════════════════════════════════════════════════════
# PROBE A: MONGOLIA ROBUSTNESS
# ══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 80)
print("PROBE A: MONGOLIA ROBUSTNESS")
print("=" * 80)

# A1: Basic tipping point excluding Mongolia.
# Each spec re-estimates the baseline CA/GDP regression on a different country
# subset and records the Z_1 coefficient. The country counts embedded in the
# labels are hard-coded, so the N_c actually estimated is printed (and saved
# in the 'n_c' column) to make any mismatch visible.
print("\n--- A1: Tipping point decomposition ---")
cca_commodity = CCA_ALL - CCA_NON_COMMODITY  # the 5 commodity CCA economies
a1_specs = [
    ("Full sample (138c)", base_sample),
    ("Excl MNG only (137c)", base_sample[base_sample['iso3'] != 'MNG']),
    ("Excl CCA-nc all 8 (130c)", base_sample[~base_sample['iso3'].isin(CCA_NON_COMMODITY)]),
    ("Excl CCA-nc minus MNG (7, 131c)", base_sample[~base_sample['iso3'].isin(CCA_NC_NO_MNG)]),
    ("Excl MNG + commodity CCA (6, 132c)", base_sample[~base_sample['iso3'].isin({'MNG'} | cca_commodity)]),
]

a1_rows = []
for name, sub in a1_specs:
    r = run_gls(sub, 'ca_gdp', base_vars)
    if not r:
        print(f"  {name:<45s}: insufficient obs / fit failed")
        continue
    row = {'sample': name, 'n_c': r['n_countries'], 'n_obs': r['n_obs'],
           'Z_1': r['coefficients']['Z_1'], 'Z_1_p': r['p_values']['Z_1'],
           'Z_1_se': r['std_errors']['Z_1'], 'r2': r['r_squared']}
    a1_rows.append(row)
    print(f"  {name:<45s}: Z₁={row['Z_1']:7.2f}{star(row['Z_1_p'])} (p={row['Z_1_p']:.4f}), "
          f"R²={row['r2']:.4f}, N_c={row['n_c']}")

# A2: Governance quintiles excluding Mongolia.
# Quintiles are cut over the pooled country-years (1 = lowest governance), and
# the baseline CA/GDP regression is re-run within each quintile.
print("\n--- A2: Governance quintiles EXCLUDING Mongolia ---")
a2_needed = base_vars + ['ca_gdp', 'governance_composite']
gov_sample_no_mng = (
    est.loc[est['iso3'].ne('MNG')]
    .dropna(subset=a2_needed)
    .copy()
)
gov_sample_no_mng['gov_quintile'] = 1 + pd.qcut(
    gov_sample_no_mng['governance_composite'], 5, labels=False)

for q, quintile_rows in gov_sample_no_mng.groupby('gov_quintile'):
    fit = run_gls(quintile_rows, 'ca_gdp', base_vars)
    if fit is None:
        print(f"  Q{q}: insufficient obs")
        continue
    print(f"  Q{q}: Z₁={fit['coefficients']['Z_1']:7.2f}{star(fit['p_values']['Z_1'])} "
          f"(p={fit['p_values']['Z_1']:.4f}), N_c={fit['n_countries']}")

# A3: Hansen threshold excluding Mongolia.
# Grid-search a single governance threshold tau. The sample is split into
# governance <= tau and > tau, each regime is fitted separately, and the tau
# with the smallest combined SSR is kept. It is then compared with the pooled
# (no-split) fit via an F statistic.
print("\n--- A3: Hansen threshold EXCLUDING Mongolia ---")
gov_h_no_mng = est[est['iso3'] != 'MNG'].dropna(
    subset=base_vars + ['ca_gdp', 'governance_composite']).copy()
r_pooled_nm = run_gls(gov_h_no_mng, 'ca_gdp', base_vars)
ssr_pooled_nm = r_pooled_nm['ssr'] if r_pooled_nm else np.inf

gov_sorted = gov_h_no_mng['governance_composite'].dropna().sort_values()
trim = int(0.15 * len(gov_sorted))
# Drop the bottom and top 15% of governance values. Candidate thresholds are
# then placed at the 10th..90th percentiles (step 5) of what remains, i.e.
# roughly the 22nd-78th percentiles of the untrimmed distribution.
# An explicit stop index is used because values[trim:-trim] is EMPTY when
# trim == 0, and np.percentile raises on an empty array.
gov_trimmed = gov_sorted.values[trim:len(gov_sorted) - trim]
if gov_trimmed.size:
    threshold_grid = np.percentile(gov_trimmed, np.arange(10, 91, 5))
else:
    threshold_grid = np.array([])

best_ssr = np.inf
best_tau = None
best_below = None
best_above = None

for tau in threshold_grid:
    below = gov_h_no_mng[gov_h_no_mng['governance_composite'] <= tau]
    above = gov_h_no_mng[gov_h_no_mng['governance_composite'] > tau]
    r_b = run_gls(below, 'ca_gdp', base_vars)
    r_a = run_gls(above, 'ca_gdp', base_vars)
    if r_b and r_a:
        total_ssr = r_b['ssr'] + r_a['ssr']
        if total_ssr < best_ssr:
            best_ssr = total_ssr
            best_tau = tau
            best_below = r_b
            best_above = r_a

if best_tau is None:
    print("  No admissible threshold: no candidate split left both regimes estimable")
else:
    k = len(base_vars)
    n1, n2 = best_below['n_obs'], best_above['n_obs']
    df_resid = n1 + n2 - 2 * k
    print(f"  Optimal threshold (excl MNG): gov = {best_tau:.3f}")
    print(f"  Below ({best_below['n_countries']}c): Z₁={best_below['coefficients']['Z_1']:.2f}"
          f"{star(best_below['p_values']['Z_1'])} (p={best_below['p_values']['Z_1']:.4f})")
    print(f"  Above ({best_above['n_countries']}c): Z₁={best_above['coefficients']['Z_1']:.2f}"
          f"{star(best_above['p_values']['Z_1'])} (p={best_above['p_values']['Z_1']:.4f})")
    if r_pooled_nm is None:
        # Without the pooled SSR the statistic would be inf and p would
        # misleadingly print as 0.0000.
        print("  F-test skipped: pooled regression failed, SSR_pooled unavailable")
    else:
        # NOTE(review): tau is chosen to minimise SSR, so under the
        # no-threshold null this sup-F statistic does not follow F(k, n-2k)
        # (Hansen 1996/2000). Treat p_f as indicative only; formal inference
        # needs a bootstrap p-value.
        F = ((ssr_pooled_nm - best_ssr) / k) / (best_ssr / df_resid)
        p_f = sp_stats.f.sf(F, k, df_resid)  # sf is more accurate than 1 - cdf in the tail
        print(f"  F={F:.2f}, p={p_f:.4f}")

# A4: Placebo excluding Mongolia.
# Compare the Z_1 shift from dropping the 7 non-MNG CCA-nc countries with the
# distribution of shifts from dropping the same number of randomly drawn
# non-CCA countries of comparable income.
print("\n--- A4: Placebo EXCLUDING Mongolia ---")
base_no_mng = base_sample[base_sample['iso3'] != 'MNG'].copy()
r_full_nm = run_gls(base_no_mng, 'ca_gdp', base_vars)
# Drop CCA-nc minus MNG (7 countries)
sub_excl = base_no_mng[~base_no_mng['iso3'].isin(CCA_NC_NO_MNG)]
r_excl_nm = run_gls(sub_excl, 'ca_gdp', base_vars)

if r_full_nm and r_excl_nm:
    true_change_nm = r_excl_nm['coefficients']['Z_1'] - r_full_nm['coefficients']['Z_1']
    print(f"  True Z₁ change (excl MNG, drop remaining 7 CCA-nc): "
          f"{r_full_nm['coefficients']['Z_1']:.2f} → {r_excl_nm['coefficients']['Z_1']:.2f} (Δ={true_change_nm:.2f})")

    # Placebo pool: non-CCA countries whose mean gdp_pc_ppp lies within
    # [0.3 x min, 3 x max] of the CCA-nc (ex-MNG) countries' means. If none
    # of those countries has GDP data, fall back to all non-CCA countries.
    country_gdp = base_no_mng.groupby('iso3')['gdp_pc_ppp'].mean().dropna()
    cca_nc_gdp = country_gdp[country_gdp.index.isin(CCA_NC_NO_MNG)]
    if len(cca_nc_gdp) > 0:
        gdp_lo = cca_nc_gdp.min() * 0.3
        gdp_hi = cca_nc_gdp.max() * 3.0
        pool = country_gdp[(country_gdp >= gdp_lo) & (country_gdp <= gdp_hi) &
                           (~country_gdp.index.isin(CCA_ALL))].index.tolist()
    else:
        pool = [c for c in base_no_mng['iso3'].unique() if c not in CCA_ALL]

    N_BOOT = 500
    # Legacy global seeding is kept so the draws reproduce earlier runs.
    np.random.seed(42)
    n_drop = len(CCA_NC_NO_MNG & set(base_no_mng['iso3'].unique()))
    placebo_changes = []
    if n_drop == 0 or len(pool) < n_drop:
        # np.random.choice(..., replace=False) raises ValueError when the pool
        # is smaller than the sample, and n_drop == 0 would be a no-op placebo.
        print(f"  Placebo skipped: need {n_drop} countries to drop, pool has {len(pool)}")
    else:
        for _ in range(N_BOOT):
            fake = np.random.choice(pool, size=n_drop, replace=False)
            sub = base_no_mng[~base_no_mng['iso3'].isin(fake)]
            r = run_gls(sub, 'ca_gdp', base_vars)
            if r:
                placebo_changes.append(r['coefficients']['Z_1'] - r_full_nm['coefficients']['Z_1'])

    if placebo_changes:
        arr = np.array(placebo_changes)
        pval = np.mean(np.abs(arr) >= np.abs(true_change_nm))
        # Report the draws that actually produced an estimate; run_gls can
        # return None for some sub-samples.
        print(f"  Placebo ({len(arr)} successful of {N_BOOT} draws of {n_drop} countries):")
        print(f"    Mean Δ: {arr.mean():.2f}, Std: {arr.std():.2f}")
        print(f"    True Δ: {true_change_nm:.2f}")
        print(f"    p-value: {pval:.4f}")
        print(f"    {'ANOMALOUS' if pval < 0.05 else 'NOT anomalous'} (excl Mongolia)")
else:
    print("  Skipped: baseline or exclusion regression failed")

# A5: Mongolia leverage diagnostics.
# Descriptive summary of Mongolia's observations, then the same summary for
# every CCA-nc country (Mongolia included) as a benchmark.
print("\n--- A5: Mongolia leverage diagnostics ---")
mng = base_sample.loc[base_sample['iso3'].eq('MNG')]
mng_ca, mng_z1, mng_z2 = mng['ca_gdp'], mng['Z_1'], mng['Z_2']
print(f"  Mongolia: {len(mng)} obs, years {mng['year'].min()}-{mng['year'].max()}")
print(f"  CA/GDP: mean={mng_ca.mean():.2f}, std={mng_ca.std():.2f}, "
      f"min={mng_ca.min():.2f}, max={mng_ca.max():.2f}")
print(f"  Z₁:     mean={mng_z1.mean():.4f}, std={mng_z1.std():.4f}")
print(f"  Z₂:     mean={mng_z2.mean():.4f}, std={mng_z2.std():.4f}")
print(f"  KAOPEN:  mean={mng['kaopen'].mean():.2f}")
print(f"  NFA/GDP: mean={mng['nfa_gdp_lag'].mean():.2f}")

for iso in sorted(CCA_NON_COMMODITY):
    country_rows = base_sample.loc[base_sample['iso3'].eq(iso)]
    if country_rows.empty:
        continue
    ca_c, z1_c = country_rows['ca_gdp'], country_rows['Z_1']
    print(f"  {iso}: N={len(country_rows)}, CA mean={ca_c.mean():7.2f}, "
          f"CA std={ca_c.std():6.2f}, Z₁ mean={z1_c.mean():7.4f}, "
          f"Z₁ std={z1_c.std():.4f}")


# ══════════════════════════════════════════════════════════════════════════
# PROBE B: INCOME BALANCE & TRADE BALANCE FRAGILITY
# ══════════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 80)
print("PROBE B: INCOME BALANCE & TRADE BALANCE FRAGILITY")
print("=" * 80)

# Merge income_balance_gdp and trade_balance_gdp from the net_gross panel.
ng_cols = ['income_balance_gdp', 'trade_balance_gdp']
ng = pd.read_csv(PROJECT_DIR / "net_gross" / "data" / "processed" / "net_gross_panel.csv",
                 usecols=['iso3', 'year'] + ng_cols)
n_dup = int(ng.duplicated(subset=['iso3', 'year']).sum())
if n_dup:
    print(f"  WARNING: {n_dup} duplicate iso3-year rows in net_gross panel; keeping first")
ng = ng.drop_duplicates(subset=['iso3', 'year'])

# If est already carries either column, a plain merge would produce *_x/*_y
# suffixes and the lookups below would raise KeyError. The net_gross values
# are the intended source, so any pre-existing copies are dropped first.
est_b = est.drop(columns=[c for c in ng_cols if c in est.columns]).merge(
    ng, on=['iso3', 'year'], how='left')
print(f"  Merged: income_balance_gdp non-null: {est_b['income_balance_gdp'].notna().sum()}")
print(f"  Merged: trade_balance_gdp non-null: {est_b['trade_balance_gdp'].notna().sum()}")

# B1: Full fragility test on each DV.
# For each dependent variable present in est_b, re-run the baseline regression
# on four country subsets and collect the Z_1 results in b_rows.
dvs = [
    ('ca_gdp', 'CA/GDP'),
    ('income_balance_gdp', 'Income balance/GDP'),
    ('trade_balance_gdp', 'Trade balance/GDP'),
    ('nfa_gdp', 'NFA/GDP'),
]

print("\n--- B1: CCA-nc fragility across DVs ---")
print(f"  {'DV':<25s} {'Sample':<20s} {'N_c':>4s} {'N':>6s} {'Z₁':>8s} {'p':>8s} {'R²':>7s}")
print("  " + "-" * 80)

# The comparison samples do not depend on the DV, so they are built once.
b1_samples = (
    ("Full", est_b),
    ("Excl CCA-nc (8)", est_b.loc[~est_b['iso3'].isin(CCA_NON_COMMODITY)]),
    ("Excl MNG only", est_b.loc[est_b['iso3'].ne('MNG')]),
    ("Excl CCA-nc excl MNG (7)", est_b.loc[~est_b['iso3'].isin(CCA_NC_NO_MNG)]),
)

b_rows = []
for dv, dv_label in ((col, label) for col, label in dvs if col in est_b.columns):
    for sample_name, sample_df in b1_samples:
        r = run_gls(sample_df, dv, base_vars)
        if r is None:
            print(f"  {dv_label:<25s} {sample_name:<20s}  —  insufficient data")
            continue
        row = {'dv': dv_label, 'sample': sample_name,
               'n_c': r['n_countries'], 'n_obs': r['n_obs'],
               'Z_1': r['coefficients']['Z_1'], 'Z_1_se': r['std_errors']['Z_1'],
               'Z_1_p': r['p_values']['Z_1'], 'r2': r['r_squared']}
        b_rows.append(row)
        print(f"  {dv_label:<25s} {sample_name:<20s} {r['n_countries']:4d} {r['n_obs']:6d} "
              f"{row['Z_1']:8.2f}{star(row['Z_1_p'])} {row['Z_1_p']:8.4f} {row['r2']:7.4f}")

# B2: Governance interaction on income balance.
# Income balance/GDP is regressed on the baseline regressors, adding in turn a
# governance control, governance x Z interactions, and CCA-group dummies.
print("\n--- B2: Governance interactions on income balance ---")
ib_sample = est_b.dropna(subset=base_vars + ['income_balance_gdp', 'governance_composite']).copy()
if len(ib_sample) > 100:
    # Demeaned governance, so the Z_k main effects in the interaction spec are
    # evaluated at sample-mean governance.
    ib_sample['gov_dm'] = ib_sample['governance_composite'] - ib_sample['governance_composite'].mean()
    for z in demo_vars:
        ib_sample[f'{z}_x_gov'] = ib_sample[z] * ib_sample['gov_dm']
    gov_int_vars = [f'{z}_x_gov' for z in demo_vars]

    # Country-group dummies. They do not depend on the spec, so they are built
    # once here rather than on every loop iteration.
    # NOTE(review): both dummies are constant within a country; if PanelGLS
    # includes country fixed effects they are not separately identified.
    # TODO confirm against model.PanelGLS.
    ib_sample['is_cca'] = ib_sample['iso3'].isin(CCA_ALL).astype(float)
    ib_sample['is_cca_non_commodity'] = ib_sample['iso3'].isin(CCA_NON_COMMODITY).astype(float)

    for spec_name, spec_vars in [
        ("Baseline", base_vars),
        ("+ Gov control", base_vars + ['governance_composite']),
        ("+ Gov interaction", base_vars + ['gov_dm'] + gov_int_vars),
        ("+ CCA dummy", base_vars + ['is_cca']),
        ("+ CCA-nc dummy", base_vars + ['is_cca_non_commodity']),
    ]:
        r = run_gls(ib_sample, 'income_balance_gdp', spec_vars)
        if not r:
            print(f"  {spec_name:<25s}: insufficient obs / fit failed")
            continue
        coefs, pvals = r['coefficients'], r['p_values']
        extra = ""
        if 'Z_1_x_gov' in coefs:
            extra += f", Z₁×gov={coefs['Z_1_x_gov']:.2f}{star(pvals['Z_1_x_gov'])}"
        if 'is_cca' in coefs:
            extra += f", CCA={coefs['is_cca']:.2f}{star(pvals['is_cca'])}"
        if 'is_cca_non_commodity' in coefs:
            extra += f", CCA-nc={coefs['is_cca_non_commodity']:.2f}{star(pvals['is_cca_non_commodity'])}"
        print(f"  {spec_name:<25s}: Z₁={coefs['Z_1']:.2f}{star(pvals['Z_1'])} "
              f"(p={pvals['Z_1']:.4f}){extra}")
else:
    print(f"  Skipped: only {len(ib_sample)} complete obs (need > 100)")

# B3: Income balance quintile splits.
# Governance quintiles (1 = lowest) are cut over the pooled income-balance
# sample, and the baseline regression is re-run within each quintile.
print("\n--- B3: Income balance by governance quintile ---")
if len(ib_sample) > 100:
    ib_sample['gov_quintile'] = 1 + pd.qcut(ib_sample['governance_composite'], 5, labels=False)
    for q, quintile_rows in ib_sample.groupby('gov_quintile'):
        fit = run_gls(quintile_rows, 'income_balance_gdp', base_vars)
        if fit is None:
            print(f"  Q{q}: insufficient obs")
            continue
        print(f"  Q{q}: Z₁={fit['coefficients']['Z_1']:7.2f}{star(fit['p_values']['Z_1'])} "
              f"(p={fit['p_values']['Z_1']:.4f}), N_c={fit['n_countries']}")


# ══════════════════════════════════════════════════════════════════════════
# PROBE C: CCA ENTRY TIMING
# ══════════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 80)
print("PROBE C: CCA ENTRY TIMING")
print("=" * 80)

# Report each CCA country's coverage in the baseline sample, and record its
# first sample year in the same pass. Iterating sorted(CCA_ALL) instead of the
# raw set (whose string order varies with hash randomisation) makes the entry
# ranking below reproducible from run to run.
cca_entry = {}
for iso in sorted(CCA_ALL):
    sub = base_sample[base_sample['iso3'] == iso]
    if len(sub) > 0:
        cca_entry[iso] = sub['year'].min()
        print(f"  {iso}: first year={sub['year'].min()}, last={sub['year'].max()}, N={len(sub)}")
        continue
    # Not in the baseline sample: check the broader estimation panel.
    sub2 = est[est['iso3'] == iso]
    if len(sub2) > 0:
        print(f"  {iso}: in est but not base_sample (missing controls), years={sub2['year'].min()}-{sub2['year'].max()}")
    else:
        print(f"  {iso}: NOT in sample")

# When do most CCA countries enter?
if cca_entry:
    # Stable sort: countries entering in the same year stay in alphabetical order.
    entry_df = pd.Series(cca_entry).sort_values(kind='stable')
    print("\n  CCA entry years:")
    for iso, yr in entry_df.items():
        print(f"    {iso}: {yr}")
    print(f"  Median entry: {entry_df.median():.0f}")


# ══════════════════════════════════════════════════════════════════════════
# PROBE D: DFBETAS — FORMAL INFLUENCE MEASURE
# ══════════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 80)
print("PROBE D: COUNTRY-LEVEL INFLUENCE (DFBETAS-STYLE)")
print("=" * 80)

# For each country, compute: (Z₁_full - Z₁_excl) / SE(Z₁_full).
# "DFBETAS-style" because, unlike textbook DFBETAS:
#   - a whole country (all its years) is deleted, not one observation;
#   - the scaling uses the full-sample SE, not the leave-one-out SE;
#   - the 2/sqrt(n) cutoff uses n = number of countries.
r_full = run_gls(base_sample, 'ca_gdp', base_vars)
if not r_full:
    print("  Skipped: full-sample regression failed")
else:
    z1_full = r_full['coefficients']['Z_1']
    se_full = r_full['std_errors']['Z_1']
    # One pass for per-country observation counts instead of re-filtering
    # base_sample for every country.
    obs_per_country = base_sample['iso3'].value_counts()

    influence_rows = []
    for iso in sorted(base_sample['iso3'].unique()):
        sub = base_sample[base_sample['iso3'] != iso]
        r = run_gls(sub, 'ca_gdp', base_vars)
        if r:
            z1_excl = r['coefficients']['Z_1']
            dfbeta = (z1_full - z1_excl) / se_full
            influence_rows.append({
                'iso3': iso,
                'Z_1_excl': z1_excl,
                'dfbeta': dfbeta,
                'n_obs_dropped': int(obs_per_country[iso]),
                'is_cca': iso in CCA_ALL,
                'is_cca_nc': iso in CCA_NON_COMMODITY,
            })

    if not influence_rows:
        # An empty DataFrame has no 'dfbeta' column, so sort_values would raise.
        print("  No leave-one-country-out regression succeeded")
    else:
        inf_df = pd.DataFrame(influence_rows)
        inf_df = inf_df.sort_values('dfbeta', key=abs, ascending=False)

        print(f"\n  Top 20 most influential countries (|DFBETA| on Z₁):")
        print(f"  {'ISO':<5s} {'DFBETA':>8s} {'Z₁ excl':>8s} {'N dropped':>9s} {'CCA?':>5s}")
        print("  " + "-" * 40)
        for _, row in inf_df.head(20).iterrows():
            cca_flag = "NC" if row['is_cca_nc'] else ("COM" if row['is_cca'] else "")
            print(f"  {row['iso3']:<5s} {row['dfbeta']:8.3f} {row['Z_1_excl']:8.2f} "
                  f"{row['n_obs_dropped']:9.0f} {cca_flag:>5s}")

        # How many countries have |DFBETA| > 2/sqrt(n)?
        cutoff = 2 / np.sqrt(r_full['n_countries'])
        influential = inf_df[inf_df['dfbeta'].abs() > cutoff]
        n_influential = len(influential)
        print(f"\n  DFBETA cutoff (2/√n = 2/√{r_full['n_countries']}): {cutoff:.4f}")
        print(f"  Countries exceeding cutoff: {n_influential}")
        for _, row in influential.iterrows():
            cca_flag = "CCA-NC" if row['is_cca_nc'] else ("CCA-COM" if row['is_cca'] else "")
            print(f"    {row['iso3']} (DFBETA={row['dfbeta']:.3f}) {cca_flag}")

        inf_df.to_csv(TABLE_DIR / "probe_dfbetas.csv", index=False)


# ══════════════════════════════════════════════════════════════════════════
# SUMMARY
# ══════════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 80)
print("SUMMARY OF DIAGNOSTIC PROBES")
print("=" * 80)

# Persist the result rows collected in Probe B1 and Probe A1, in that order.
for result_rows, csv_name in (
    (b_rows, "probe_income_balance_fragility.csv"),
    (a1_rows, "probe_mongolia_robustness.csv"),
):
    pd.DataFrame(result_rows).to_csv(TABLE_DIR / csv_name, index=False)

print("\nDone. Results saved to output/tables/probe_*.csv")
